{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Train.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2dOaoFoIv-F2"
      },
      "source": [
        "# **ArtLine**\n",
        "**Create amazing lineart portraits**\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JN8RJVoKv8Lt"
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import fastai\n",
        "from fastai.vision import *\n",
        "from fastai.callbacks import *\n",
        "from fastai.vision.gan import *\n",
        "from torchvision.models import vgg16_bn\n",
        "from fastai.utils.mem import *\n",
        "from PIL import Image\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from torch.autograd import Variable\n",
        "import torchvision.transforms as transforms"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MYpkQrvncqB1"
      },
      "source": [
        "## **Edge Detection**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CqG1XipRciT5"
      },
      "source": [
        "def _gradient_img(img):\n",
        "    \"\"\"Return the Sobel gradient magnitude of the first channel of `img`.\n",
        "\n",
        "    Assumes `img` is a (C, H, W) float tensor; a 4-D input with a leading\n",
        "    singleton axis is squeezed first. Only channel 0 is used and the result\n",
        "    has shape (1, H, W).\n",
        "    \"\"\"\n",
        "    # Squeeze only a real batch axis: an unconditional squeeze(0) would also\n",
        "    # drop the channel axis of a single-channel (1, H, W) image.\n",
        "    if img.dim() == 4:\n",
        "        img = img.squeeze(0)\n",
        "    x = img[0].unsqueeze(0).unsqueeze(0)  # (1, 1, H, W), the layout conv2d expects\n",
        "\n",
        "    # Fixed Sobel kernels, created with the input's dtype and device.\n",
        "    sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=x.dtype, device=x.device)\n",
        "    sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=x.dtype, device=x.device)\n",
        "\n",
        "    # Functional convolution with constant weights: no per-call modules and no\n",
        "    # autograd graph (replaces the deprecated Variable(...)/.data round trip).\n",
        "    with torch.no_grad():\n",
        "        G_x = nn.functional.conv2d(x, sobel_x[None, None], padding=1)\n",
        "        G_y = nn.functional.conv2d(x, sobel_y[None, None], padding=1)\n",
        "\n",
        "    G = torch.sqrt(torch.pow(G_x, 2) + torch.pow(G_y, 2))\n",
        "    return G.view(1, x.shape[2], x.shape[3])\n",
        "\n",
        "gradient = TfmPixel(_gradient_img)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OlGPepVInaQo"
      },
      "source": [
        "**PATH**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PwBuA4R0m2zf"
      },
      "source": [
        "# Root folder of the dataset; every sub-folder below is resolved against it.\n",
        "path = Path('/content/gdrive/My Drive/Apdrawing')\n",
        "\n",
        "# Blended Facial Features: path_lr feeds the inputs, path_hr the targets.\n",
        "\n",
        "path_hr = path/'draw tiny'\n",
        "path_lr = path/'Tiny Real'\n",
        "\n",
        "# Portrait Pair: path_lr3 feeds the inputs, path_hr3 the targets.\n",
        "\n",
        "path_hr3 = path/'drawing'\n",
        "path_lr3 = path/'Real'\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q0wVDG4onf7G"
      },
      "source": [
        "**Architecture**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0zieELiLnSbr"
      },
      "source": [
        "arch = models.resnet34"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LvQrKkYkshF7"
      },
      "source": [
        "### **Facial Features**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5S0vl06RnTMv"
      },
      "source": [
        "src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.3, seed=42)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8tJphvDRnkO8"
      },
      "source": [
        "def get_data(bs,size):\n",
        "    \"\"\"Build the facial-feature DataBunch for batch size `bs` and image size `size`.\n",
        "\n",
        "    Inputs come from the global `src`; each target is the file with the same\n",
        "    name under `path_hr`. The `gradient` transform is added to the default\n",
        "    transforms, which are applied to inputs and targets alike (tfm_y=True).\n",
        "    \"\"\"\n",
        "    labelled = src.label_from_func(lambda x: path_hr/x.name)\n",
        "    tfms = get_transforms(xtra_tfms=[gradient()])\n",
        "    transformed = labelled.transform(tfms, size=size, tfm_y=True)\n",
        "    bunch = transformed.databunch(bs=bs, num_workers=0)\n",
        "    data = bunch.normalize(imagenet_stats, do_y=True)\n",
        "\n",
        "    data.c = 3\n",
        "    return data"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yetcm3CHn0Dq"
      },
      "source": [
        "**64px**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k-Xoit4Mnvb-"
      },
      "source": [
        "bs,size=20,64"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-fqJZQfanmmV"
      },
      "source": [
        "data = get_data(bs,size)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4a-wREZlnpD9"
      },
      "source": [
        "data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JLOmfrvWn4Uy"
      },
      "source": [
        "t = data.valid_ds[0][1].data\n",
        "t = torch.stack([t,t])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "it5H7CoPn4q_"
      },
      "source": [
        "def gram_matrix(x):\n",
        "    \"\"\"Return the per-sample Gram matrix of a 4-D tensor, scaled by 1/(c*h*w).\n",
        "\n",
        "    `x` is unpacked as (n, c, h, w); the result has shape (n, c, c).\n",
        "    \"\"\"\n",
        "    batch, channels, height, width = x.size()\n",
        "    flat = x.view(batch, channels, -1)  # flatten the spatial axes\n",
        "    gram = torch.matmul(flat, flat.transpose(1, 2))\n",
        "    return gram / (channels * height * width)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fzkaAicvn7Gi"
      },
      "source": [
        "gram_matrix(t)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XkORIHa4n-Fh"
      },
      "source": [
        "base_loss = F.l1_loss\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QYQnhpgxoAEF"
      },
      "source": [
        "# VGG16-BN feature layers, moved to the GPU and switched to eval mode;\n",
        "# `True` is the first positional argument (`pretrained` in torchvision).\n",
        "vgg_m = vgg16_bn(True).features.cuda().eval()\n",
        "# The VGG network is only used to compute the loss, so its parameters are not trained.\n",
        "requires_grad(vgg_m, False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9CnEahFKoCPB"
      },
      "source": [
        "# Index of the layer immediately preceding each MaxPool2d in vgg_m.\n",
        "blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]\n",
        "# Display the indices together with the layers they select.\n",
        "blocks, [vgg_m[i] for i in blocks]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NwtR5uSAoFgR"
      },
      "source": [
        "class FeatureLoss(nn.Module):\n",
        "    \"\"\"Pixel + feature + Gram-matrix loss computed through a fixed feature network.\n",
        "\n",
        "    `m_feat` is the feature network, `layer_ids` the indices of the layers whose\n",
        "    outputs are compared, and `layer_wgts` one weight per selected layer.\n",
        "    After each forward pass `self.metrics` maps every name in\n",
        "    `self.metric_names` to the matching loss term.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, m_feat, layer_ids, layer_wgts):\n",
        "        super().__init__()\n",
        "        self.m_feat = m_feat\n",
        "        self.loss_features = [self.m_feat[idx] for idx in layer_ids]\n",
        "        # Hooks keep the outputs attached to the graph (detach=False).\n",
        "        self.hooks = hook_outputs(self.loss_features, detach=False)\n",
        "        self.wgts = layer_wgts\n",
        "        n_layers = len(layer_ids)\n",
        "        feat_names = [f'feat_{i}' for i in range(n_layers)]\n",
        "        gram_names = [f'gram_{i}' for i in range(n_layers)]\n",
        "        self.metric_names = ['pixel',] + feat_names + gram_names\n",
        "\n",
        "    def make_features(self, x, clone=False):\n",
        "        \"\"\"Run `x` through the feature network and return the hooked outputs.\"\"\"\n",
        "        self.m_feat(x)\n",
        "        stored = self.hooks.stored\n",
        "        if clone:\n",
        "            return [act.clone() for act in stored]\n",
        "        return [act for act in stored]\n",
        "\n",
        "    def forward(self, input, target):\n",
        "        \"\"\"Return the sum of the pixel, feature and Gram terms for (input, target).\"\"\"\n",
        "        # Target activations are cloned because the next pass refills the hooks.\n",
        "        target_feats = self.make_features(target, clone=True)\n",
        "        input_feats = self.make_features(input)\n",
        "\n",
        "        pixel_terms = [base_loss(input,target)]\n",
        "\n",
        "        feature_terms = []\n",
        "        for f_in, f_out, w in zip(input_feats, target_feats, self.wgts):\n",
        "            feature_terms.append(base_loss(f_in, f_out)*w)\n",
        "\n",
        "        gram_terms = []\n",
        "        for f_in, f_out, w in zip(input_feats, target_feats, self.wgts):\n",
        "            gram_terms.append(base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3)\n",
        "\n",
        "        self.feat_losses = pixel_terms + feature_terms + gram_terms\n",
        "        self.metrics = dict(zip(self.metric_names, self.feat_losses))\n",
        "        return sum(self.feat_losses)\n",
        "\n",
        "    def __del__(self):\n",
        "        # Detach the forward hooks when the loss object is collected.\n",
        "        self.hooks.remove()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Rwt3AAMDoGKv"
      },
      "source": [
        "feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2YtjntlCoIgo"
      },
      "source": [
        "wd = 1e-3\n",
        "y_range = (-3.,3.)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "faBs7PvAoQcY"
      },
      "source": [
        "def create_gen_learner():\n",
        "    \"\"\"Build the U-Net generator learner from the notebook-level configuration.\n",
        "\n",
        "    Reads the globals `data`, `arch`, `wd`, `y_range` and `feat_loss`; the output\n",
        "    range comes from `y_range` so it is configured in one place only.\n",
        "    \"\"\"\n",
        "    return unet_learner(data, arch, wd=wd, blur=True, norm_type=NormType.Spectral,\n",
        "                        self_attention=True, y_range=y_range,\n",
        "                        loss_func=feat_loss, callback_fns=LossMetrics)\n",
        "gc.collect();"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7UIo4hEwoS4o"
      },
      "source": [
        "learn_gen = create_gen_learner()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z_WqNg6NoU5I"
      },
      "source": [
        "learn_gen.lr_find()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "l4OK0nMYoYRI"
      },
      "source": [
        "lr = 1e-01\n",
        "epoch = 5\n",
        "def do_fit(save_name, lrs=None, pct_start=0.9):\n",
        "    \"\"\"Train `learn_gen` for `epoch` epochs, save it as `save_name`, show results.\n",
        "\n",
        "    `lrs` is the learning rate (or slice of rates) for fit_one_cycle. When it\n",
        "    is omitted, `slice(lr)` is built from the global `lr` at call time, so\n",
        "    reassigning `lr` in a later cell takes effect on the next call.\n",
        "    `epoch` is likewise read from the global at call time.\n",
        "    \"\"\"\n",
        "    # A `slice(lr)` default in the signature would be evaluated once, at\n",
        "    # definition time, and ignore later changes to `lr`.\n",
        "    if lrs is None:\n",
        "        lrs = slice(lr)\n",
        "    learn_gen.fit_one_cycle(epoch, lrs, pct_start=pct_start)\n",
        "    learn_gen.save(save_name)\n",
        "    learn_gen.show_results(rows=1, imgsize=5)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D4EcXwfXoc2Q"
      },
      "source": [
        "do_fit('da', slice(lr))\n",
        "#lr*10"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gzoUTyG7pm-d"
      },
      "source": [
        "learn_gen.unfreeze()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4kIQ5qGYppbS"
      },
      "source": [
        "learn_gen.lr_find()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BB2tZvVzpt-s"
      },
      "source": [
        "epoch = 5\n",
        "do_fit('db', slice(1E-2))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FtEI1s2BqFz2"
      },
      "source": [
        "**128px**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XqKWCsBgp236"
      },
      "source": [
        "# Progressive resizing: rebuild the databunch at 128px with batch size 8.\n",
        "data = get_data(8,128)\n",
        "learn_gen.data = data\n",
        "learn_gen.freeze()\n",
        "gc.collect()\n",
        "# Reload the weights saved as 'db'; the trailing `;` hides the returned object's repr.\n",
        "learn_gen.load('db');"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RlFdd8pOp9N1"
      },
      "source": [
        "epoch =5\n",
        "lr = 1E-03\n",
        "do_fit('db2',slice(lr))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bDFMDymEqDDm"
      },
      "source": [
        "learn_gen.unfreeze()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "55vcLnbrqK2p"
      },
      "source": [
        "epoch = 5\n",
        "# slice(start, stop): the earliest layer group trains at `start`, the last at\n",
        "# `stop`, so the smaller rate goes first, as in the other unfrozen stages.\n",
        "do_fit('db3', slice(1e-5,1e-02), pct_start=0.3)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GMSbEIC8qYv1"
      },
      "source": [
        "**192px**\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BLmzmoJ7qTJy"
      },
      "source": [
        "# Progressive resizing: rebuild the databunch at 192px with batch size 5.\n",
        "data = get_data(5,192)\n",
        "learn_gen.data = data\n",
        "learn_gen.freeze()\n",
        "gc.collect()\n",
        "# Reload the weights saved as 'db3'; the trailing `;` hides the returned object's repr.\n",
        "learn_gen.load('db3');"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VdxT7G9Wqdud"
      },
      "source": [
        "epoch =5\n",
        "lr = 1E-06\n",
        "# Pass the rate explicitly so this stage trains with the `lr` set above.\n",
        "do_fit('db4', slice(lr))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gABZH1FwqjeG"
      },
      "source": [
        "learn_gen.unfreeze()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WqWh1dLTql8z"
      },
      "source": [
        "epoch = 5\n",
        "do_fit('db5', slice(1e-06,1e-4), pct_start=0.3)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I_YqXNzjq8m3"
      },
      "source": [
        "# **Portraits**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pD6ognJgqqd6"
      },
      "source": [
        "src = ImageImageList.from_folder(path_lr3).split_by_rand_pct(0.2, seed=42)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uFws14t6rA6y"
      },
      "source": [
        "def get_data(bs,size):\n",
        "    \"\"\"Build the portrait DataBunch for batch size `bs` and image size `size`.\n",
        "\n",
        "    Inputs come from the global `src`; each target is the file with the same\n",
        "    name under `path_hr3`. Default transforms with max_zoom=2. are applied to\n",
        "    inputs and targets alike (tfm_y=True). Redefines the earlier `get_data`.\n",
        "    \"\"\"\n",
        "    labelled = src.label_from_func(lambda x: path_hr3/x.name)\n",
        "    tfms = get_transforms(max_zoom=2.)\n",
        "    transformed = labelled.transform(tfms, size=size, tfm_y=True)\n",
        "    bunch = transformed.databunch(bs=bs, num_workers=0)\n",
        "    data = bunch.normalize(imagenet_stats, do_y=True)\n",
        "\n",
        "    data.c = 3\n",
        "    return data"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "45kYzAwvrL9T"
      },
      "source": [
        "**128px**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ish7mEzdrEDa"
      },
      "source": [
        "# Rebuild the databunch at 128px with batch size 8 and attach it to the learner.\n",
        "data = get_data(8,128)\n",
        "learn_gen.data = data\n",
        "learn_gen.freeze()\n",
        "gc.collect()\n",
        "# Reload the weights saved as 'db5'; the trailing `;` hides the returned object's repr.\n",
        "learn_gen.load('db5');"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HdUSomYArTNH"
      },
      "source": [
        "data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zAlhXL46re94"
      },
      "source": [
        "learn_gen.lr_find()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ADT1-uoyrhbA"
      },
      "source": [
        "epoch = 5\n",
        "lr = 1e-03\n",
        "# Pass the rate explicitly so this stage trains with the `lr` set above.\n",
        "do_fit('db6', slice(lr))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "whapmkF0rmHu"
      },
      "source": [
        "learn_gen.unfreeze()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tdO4W55TrqQE"
      },
      "source": [
        "epoch = 5\n",
        "do_fit('db7', slice(6.31E-07,1e-5), pct_start=0.3)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WMHRoP1Vr5zA"
      },
      "source": [
        "**192px**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "B_YCV5tYryyV"
      },
      "source": [
        "# Rebuild the databunch at 192px with batch size 4 and attach it to the learner.\n",
        "data = get_data(4,192)\n",
        "learn_gen.data = data\n",
        "learn_gen.freeze()\n",
        "gc.collect()\n",
        "# Reload the weights saved as 'db7'; the trailing `;` hides the returned object's repr.\n",
        "learn_gen.load('db7');"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e0buABN-r82o"
      },
      "source": [
        "learn_gen.lr_find()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "crNFdzojr_vS"
      },
      "source": [
        "epoch = 5\n",
        "lr = 4.37E-05\n",
        "# Pass the rate explicitly so this stage trains with the `lr` set above.\n",
        "do_fit('db8', slice(lr))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mRJ_954VsDW7"
      },
      "source": [
        "learn_gen.unfreeze()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Mr8mvuhVsGk2"
      },
      "source": [
        "epoch = 5\n",
        "do_fit('db9', slice(1.00E-05,1e-3), pct_start=0.3)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
